#!/usr/bin/env python

import os
import numpy as np
import pandas as pd
from astropy.cosmology import Planck15 as cosmo
import astropy.units as u

# Input catalogue and output lens table; both paths are relative to the
# current working directory.
INPUT_CSV = os.path.join("data", "KiDS_DR5_lenssample.csv")
OUTPUT_CSV = os.path.join("data", "lenses_dr5.csv")

# Bin edges applied to the "Mstar_log10" column, and the label written to the
# output for each bin. Consecutive edge pairs form one bin, so there is one
# label fewer than there are edges. Bins are left-closed / right-open because
# assign_bins() calls pd.cut with right=False.
MSTAR_EDGES = [10.2, 10.5, 10.8, 11.1]
MSTAR_LABELS = ["10.2–10.5", "10.5–10.8", "10.8–11.1"]

# Bin edges applied to the "R_G_kpc" column and their labels; same
# left-closed / right-open convention as the mass bins.
RG_EDGES = [1.5, 3.0, 5.0, 8.0, 12.0]
RG_LABELS = ["1.5–3.0", "3.0–5.0", "5.0–8.0", "8.0–12.0"]


def compute_R_G_kpc(df: pd.DataFrame) -> pd.Series:
    """Convert each row's angular size into a physical size in kpc.

    The angular size is the geometric mean ``sqrt(A_WORLD * B_WORLD)`` of the
    two axis columns (taken to be in degrees), and it is scaled by the
    Planck15 angular diameter distance evaluated at the row's ``Z_B``.

    Parameters
    ----------
    df : pd.DataFrame
        Must provide the columns ``A_WORLD``, ``B_WORLD`` and ``Z_B``.

    Returns
    -------
    pd.Series
        Size in kpc, carrying the same index as ``df``. Redshifts that are
        non-positive or non-finite are replaced with NaN before the distance
        is evaluated.
    """
    # Mask unusable redshifts up front so they never reach the cosmology call
    # as real values.
    redshift = df["Z_B"].values
    usable = (redshift > 0) & np.isfinite(redshift)
    redshift = np.where(usable, redshift, np.nan)

    distance = cosmo.angular_diameter_distance(redshift * u.dimensionless_unscaled)
    distance_kpc = distance.to(u.kpc).value

    axis_a_deg = df["A_WORLD"].values  # degrees
    axis_b_deg = df["B_WORLD"].values  # degrees
    # Geometric-mean size expressed in arcsec, then taken back to degrees and
    # on to radians for the small-angle product below.
    geom_mean_arcsec = 3600.0 * np.sqrt(axis_a_deg * axis_b_deg)
    angle_rad = np.deg2rad(geom_mean_arcsec / 3600.0)

    return pd.Series(angle_rad * distance_kpc, index=df.index)


def assign_bins(values: pd.Series, edges, labels) -> pd.Series:
    """Label every entry of ``values`` with the bin it falls into.

    Parameters
    ----------
    values : pd.Series
        Numbers to bin.
    edges : sequence of float
        Monotonic bin edges; each consecutive pair delimits one bin.
    labels : sequence of str
        One label per bin, i.e. ``len(edges) - 1`` entries.

    Returns
    -------
    pd.Series
        Categorical series aligned with ``values``. Bins include their left
        edge and exclude their right edge; entries outside every bin (and NaN
        inputs) come back as NaN.
    """
    binned = pd.cut(
        x=values,
        bins=edges,
        right=False,  # left-closed, right-open intervals
        labels=labels,
    )
    return binned


def main():
    """Build the binned lens table from the DR5 lens-sample CSV.

    Steps, in order:

    1. Read ``INPUT_CSV`` and verify that every required column is present.
    2. Pick one stellar mass per row: ``mstar_med`` when it is positive and
       finite, otherwise ``mstar_bestfit`` when that is positive and finite,
       otherwise NaN. The result is stored as ``Mstar_log10``.
    3. Keep rows with 9.0 < Mstar_log10 < 12.5, 0 < Z_B < 2 and finite,
       positive ``A_WORLD`` / ``B_WORLD``.
    4. Compute ``R_G_kpc`` and drop rows where it is non-positive or
       non-finite.
    5. Assign mass and size bins and drop rows that fall outside either set
       of bins.
    6. Write the selected columns to ``OUTPUT_CSV``.

    Raises
    ------
    SystemExit
        If ``INPUT_CSV`` does not exist or a required column is missing.
    """
    if not os.path.exists(INPUT_CSV):
        raise SystemExit(f"Input CSV not found: {INPUT_CSV}")

    print("[info] Reading", INPUT_CSV)
    df = pd.read_csv(INPUT_CSV)

    required = ["ID", "RAJ2000", "DECJ2000", "Z_B",
                "mstar_med", "mstar_bestfit", "A_WORLD", "B_WORLD"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        # Report every absent column at once (one line each) so a malformed
        # catalogue can be diagnosed in a single run.
        raise SystemExit("\n".join(
            f"Missing required column '{c}' in DR5 CSV." for c in missing
        ))

    # Prefer the median mass estimate; fall back to the best-fit value only
    # when the median is unusable.
    m_med = df["mstar_med"].values
    m_best = df["mstar_bestfit"].values
    m_med_valid = (m_med > 0) & np.isfinite(m_med)
    m_best_valid = (m_best > 0) & np.isfinite(m_best)
    df["Mstar_log10"] = np.where(m_med_valid, m_med,
                                 np.where(m_best_valid, m_best, np.nan))

    # Pull each column once instead of re-indexing the frame per condition.
    mstar = df["Mstar_log10"].values
    z_b = df["Z_B"].values
    a_world = df["A_WORLD"].values
    b_world = df["B_WORLD"].values
    mask = (
        np.isfinite(mstar)
        & (mstar > 9.0)
        & (mstar < 12.5)
        & (z_b > 0.0)
        & (z_b < 2.0)
        & np.isfinite(a_world)
        & np.isfinite(b_world)
        & (a_world > 0.0)
        & (b_world > 0.0)
    )
    df = df.loc[mask].copy()
    print(f"[info] Rows after basic cuts: {len(df)}")

    print("[info] Computing R_G_kpc...")
    df["R_G_kpc"] = compute_R_G_kpc(df)
    df = df[(df["R_G_kpc"] > 0) & np.isfinite(df["R_G_kpc"])].copy()
    print(f"[info] Rows after size filter: {len(df)}")

    print("[info] Assigning mass and size bins...")
    df["Mstar_bin"] = assign_bins(df["Mstar_log10"], MSTAR_EDGES, MSTAR_LABELS)
    df["R_G_bin"] = assign_bins(df["R_G_kpc"], RG_EDGES, RG_LABELS)
    # Rows outside the bin edges get NaN labels from assign_bins; drop them.
    df = df[df["Mstar_bin"].notna() & df["R_G_bin"].notna()].copy()
    print(f"[info] Rows inside desired mass+size bins: {len(df)}")

    out = pd.DataFrame(
        {
            "lens_id": df["ID"],
            "ra_deg": df["RAJ2000"],
            "dec_deg": df["DECJ2000"],
            "z_lens": df["Z_B"],
            "R_G_kpc": df["R_G_kpc"],
            "Mstar_log10": df["Mstar_log10"],
            "R_G_bin": df["R_G_bin"].astype(str),
            "Mstar_bin": df["Mstar_bin"].astype(str),
        }
    )

    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when the output path actually has a directory component.
    out_dir = os.path.dirname(OUTPUT_CSV)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    print("[info] Writing lens table to", OUTPUT_CSV)
    out.to_csv(OUTPUT_CSV, index=False)
    print("[info] Done. Wrote", len(out), "lenses.")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
